{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "from __future__ import (print_function, absolute_import)\n",
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import keras\n",
    "import keras.backend as K\n",
    "from keras import datasets\n",
    "from keras.preprocessing.image import ImageDataGenerator\n",
    "from keras.models import load_model\n",
    "\n",
    "from models import CNN, VGG8\n",
    "from wide_resnet import WideResidualNetwork\n",
    "from pulearn.utils import synthesize_pu_labels\n",
    "\n",
    "from keras.callbacks import (\n",
    "    ReduceLROnPlateau,\n",
    "    LearningRateScheduler,\n",
    "    CSVLogger,\n",
    "    EarlyStopping,\n",
    "    ModelCheckpoint)\n",
    "\n",
    "from keras_tqdm import TQDMNotebookCallback\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
    "\n",
    "global _SESSION\n",
    "config = tf.ConfigProto(allow_soft_placement=True)\n",
    "config.gpu_options.allow_growth = True\n",
    "_SESSION = tf.Session(config=config)\n",
    "K.set_session(_SESSION)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "# CIFAR110"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "num_classes = 11\n",
    "batch_size = 128\n",
    "epochs = 200\n",
    "data_augmentation = True\n",
    "checkpoint = None\n",
    "# checkpoint = 'model_checkpoint_cifar10_wide_resnet.h5'\n",
    "title = 'cifar110_vgg8'\n",
    "# title = 'cifar110_wide_resnet'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "(x_train_10, y_train_10), (x_test_10, y_test_10) = datasets.cifar10.load_data()\n",
    "(x_train_100, y_train_100), (x_test_100, y_test_100) = datasets.cifar100.load_data()\n",
    "\n",
    "y_train_10 = y_train_10 + 1\n",
    "y_test_10 = y_test_10 + 1\n",
    "y_train_100[...] = 0\n",
    "y_test_100[...] = 0\n",
    "\n",
    "x_train = np.concatenate((x_train_10, x_train_100), axis=0).astype('float32')\n",
    "y_train = np.concatenate((y_train_10, y_train_100), axis=0).astype('int8')\n",
    "x_test = np.concatenate((x_test_10, x_test_100), axis=0).astype('float32')\n",
    "y_test = np.concatenate((y_test_10, y_test_100), axis=0).astype('int8')\n",
    "\n",
    "x_train = x_train.astype(K.floatx())\n",
    "x_test = x_test.astype(K.floatx())\n",
    "\n",
    "y_train = keras.utils.to_categorical(y_train, num_classes)\n",
    "y_test = keras.utils.to_categorical(y_test, num_classes)\n",
    "\n",
    "def normalize(x):\n",
    "    \"\"\"Substract mean and Divide by std.\"\"\"\n",
    "    x -= np.array([125.3, 123.0, 113.9], dtype=K.floatx())\n",
    "    x /= np.array([63.0, 62.1, 66.7], dtype=K.floatx())\n",
    "    return x\n",
    "\n",
    "\n",
    "x_train = normalize(x_train)\n",
    "x_test = normalize(x_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(100000, 32, 32, 3) (100000, 11)\n",
      "(20000, 32, 32, 3) (20000, 11)\n"
     ]
    }
   ],
   "source": [
    "print(x_train.shape, y_train.shape)\n",
    "print(x_test.shape, y_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('Positive (pct_missing=0.0):', 50000, ' -> ', 50000)\n",
      "('Positive (pct_missing=0.1):', 50000, ' -> ', 44971)\n",
      "('Positive (pct_missing=0.2):', 50000, ' -> ', 40026)\n",
      "('Positive (pct_missing=0.3):', 50000, ' -> ', 35185)\n",
      "('Positive (pct_missing=0.4):', 50000, ' -> ', 29980)\n",
      "('Positive (pct_missing=0.5):', 50000, ' -> ', 24759)\n",
      "('Positive (pct_missing=0.6):', 50000, ' -> ', 19982)\n",
      "('Positive (pct_missing=0.7):', 50000, ' -> ', 15187)\n",
      "('Positive (pct_missing=0.8):', 50000, ' -> ', 9823)\n",
      "('Positive (pct_missing=0.9):', 50000, ' -> ', 4970)\n",
      "('Positive (pct_missing=1.0):', 50000, ' -> ', 0)\n"
     ]
    }
   ],
   "source": [
    "pct_missing = 0.5\n",
    "y_train_pu = synthesize_pu_labels(y_train, random_state=None, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "conv2d_13 (Conv2D)           (None, 32, 32, 32)        896       \n",
      "_________________________________________________________________\n",
      "leaky_re_lu_13 (LeakyReLU)   (None, 32, 32, 32)        0         \n",
      "_________________________________________________________________\n",
      "conv2d_14 (Conv2D)           (None, 32, 32, 32)        9248      \n",
      "_________________________________________________________________\n",
      "leaky_re_lu_14 (LeakyReLU)   (None, 32, 32, 32)        0         \n",
      "_________________________________________________________________\n",
      "max_pooling2d_7 (MaxPooling2 (None, 16, 16, 32)        0         \n",
      "_________________________________________________________________\n",
      "dropout_9 (Dropout)          (None, 16, 16, 32)        0         \n",
      "_________________________________________________________________\n",
      "conv2d_15 (Conv2D)           (None, 16, 16, 64)        18496     \n",
      "_________________________________________________________________\n",
      "leaky_re_lu_15 (LeakyReLU)   (None, 16, 16, 64)        0         \n",
      "_________________________________________________________________\n",
      "conv2d_16 (Conv2D)           (None, 16, 16, 64)        36928     \n",
      "_________________________________________________________________\n",
      "leaky_re_lu_16 (LeakyReLU)   (None, 16, 16, 64)        0         \n",
      "_________________________________________________________________\n",
      "max_pooling2d_8 (MaxPooling2 (None, 8, 8, 64)          0         \n",
      "_________________________________________________________________\n",
      "dropout_10 (Dropout)         (None, 8, 8, 64)          0         \n",
      "_________________________________________________________________\n",
      "conv2d_17 (Conv2D)           (None, 8, 8, 128)         73856     \n",
      "_________________________________________________________________\n",
      "leaky_re_lu_17 (LeakyReLU)   (None, 8, 8, 128)         0         \n",
      "_________________________________________________________________\n",
      "conv2d_18 (Conv2D)           (None, 8, 8, 128)         147584    \n",
      "_________________________________________________________________\n",
      "leaky_re_lu_18 (LeakyReLU)   (None, 8, 8, 128)         0         \n",
      "_________________________________________________________________\n",
      "max_pooling2d_9 (MaxPooling2 (None, 4, 4, 128)         0         \n",
      "_________________________________________________________________\n",
      "dropout_11 (Dropout)         (None, 4, 4, 128)         0         \n",
      "_________________________________________________________________\n",
      "flatten_3 (Flatten)          (None, 2048)              0         \n",
      "_________________________________________________________________\n",
      "dense_5 (Dense)              (None, 512)               1049088   \n",
      "_________________________________________________________________\n",
      "activation_5 (Activation)    (None, 512)               0         \n",
      "_________________________________________________________________\n",
      "dropout_12 (Dropout)         (None, 512)               0         \n",
      "_________________________________________________________________\n",
      "dense_6 (Dense)              (None, 11)                5643      \n",
      "_________________________________________________________________\n",
      "activation_6 (Activation)    (None, 11)                0         \n",
      "=================================================================\n",
      "Total params: 1,341,739\n",
      "Trainable params: 1,341,739\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "if title == 'cifar110_vgg8':\n",
    "    model = VGG8(input_shape=x_train.shape[1:], num_classes=num_classes)\n",
    "elif title == 'cifar110_wide_resnet':\n",
    "    model = WideResidualNetwork(depth=28, width=8, dropout_rate=0.5,\n",
    "                                classes=num_classes, include_top=True,\n",
    "                                weights=None)\n",
    "if checkpoint is not None:\n",
    "    model = load_model(checkpoint)\n",
    "    \n",
    "optimizer = keras.optimizers.Adam(lr=1e-4)\n",
    "model.compile(loss='categorical_crossentropy',\n",
    "              optimizer=optimizer,\n",
    "              metrics=['accuracy'])\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "# Checkpoint\n",
    "checkpointer = ModelCheckpoint(\n",
    "    monitor='val_acc',\n",
    "    filepath=\"model_checkpoint_{}_pu_{}.h5\".format(title, pct_missing),\n",
    "    verbose=1,\n",
    "    save_best_only=True)\n",
    "\n",
    "# csvlogger\n",
    "csv_logger = CSVLogger(\n",
    "    'csv_logger_{}_pu_{}.csv'.format(title, pct_missing))\n",
    "\n",
    "# EarlyStopping\n",
    "early_stopper = EarlyStopping(monitor='val_acc',\n",
    "                              min_delta=0.001,\n",
    "                              patience=50)\n",
    "# Reduce lr with schedule\n",
    "# def schedule(epoch):\n",
    "#     lr = K.get_value(model.optimizer.lr)\n",
    "#     if epoch in [60, 120, 160]:\n",
    "#         lr = lr * np.sqrt(0.1)\n",
    "#     return lr\n",
    "# lr_scheduler = LearningRateScheduler(schedule)\n",
    "\n",
    "# Reduce lr on plateau\n",
    "lr_reducer = ReduceLROnPlateau(monitor='val_acc',\n",
    "                               factor=np.sqrt(0.1),\n",
    "                               cooldown=0,\n",
    "                               patience=20,\n",
    "                               min_lr=1e-6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using real-time data augmentation.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f0b294931a58408a9d3e6cda9d94933d"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3b1f0d3029e34387a8a333bbe6be3204"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/200\n",
      " 56/781 [=>............................] - ETA: 23s - loss: 1.4638 - acc: 0.7241 \b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b"
     ]
    }
   ],
   "source": [
    "if not data_augmentation:\n",
    "    print('No data augmentation applied.')\n",
    "    model.fit(x_train, y_train,\n",
    "              batch_size=batch_size,\n",
    "              epochs=epochs,\n",
    "              validation_data=(x_test, y_test),\n",
    "              class_weight=class_weight,\n",
    "              shuffle=True,\n",
    "              callbacks=[csv_logger, checkpointer, early_stopper, lr_reducer]) \n",
    "else:\n",
    "    print('Using real-time data augmentation.')\n",
    "    # This will do preprocessing and realtime data augmentation:\n",
    "    datagen = ImageDataGenerator(\n",
    "        featurewise_center=False,  # set input mean to 0 over the dataset\n",
    "        samplewise_center=False,  # set each sample mean to 0\n",
    "        featurewise_std_normalization=False,  # divide inputs by std\n",
    "        samplewise_std_normalization=False,  # divide each input by its std\n",
    "        zca_whitening=False,  # apply ZCA whitening\n",
    "        # randomly rotate images in the range (degrees, 0 to 180)\n",
    "        rotation_range=0,\n",
    "        # randomly shift images horizontally (fraction of total width)\n",
    "        width_shift_range=0.1,\n",
    "        # randomly shift images vertically (fraction of total height)\n",
    "        height_shift_range=0.1,\n",
    "        horizontal_flip=True,  # randomly flip images\n",
    "        vertical_flip=False)  # randomly flip images\n",
    "\n",
    "    # Compute quantities required for feature-wise normalization\n",
    "    # (std, mean, and principal components if ZCA whitening is applied).\n",
    "    datagen.fit(x_train)\n",
    "\n",
    "    # Fit the model on the batches generated by datagen.flow().\n",
    "    model.fit_generator(datagen.flow(x_train, y_train_pu[pct_missing],\n",
    "                                     batch_size=batch_size),\n",
    "                        steps_per_epoch=x_train.shape[0] // batch_size,\n",
    "                        epochs=epochs,\n",
    "                        validation_data=(x_test, y_test),\n",
    "                        callbacks=[checkpointer, early_stopper,\n",
    "                                   lr_reducer, csv_logger,\n",
    "                                   TQDMNotebookCallback()])\n",
    "    model.save('{}_pu_{}.h5'.format(title, pct_missing))\n",
    "    \n",
    "    # Predict\n",
    "    y_train_pred_enc = model.predict(x_train)\n",
    "    y_test_pred_enc = model.predict(x_test)\n",
    "    \n",
    "    lst = []\n",
    "    y_train_pred = np.argmax(y_train_pred_enc, axis=-1)\n",
    "    y_train_true = np.argmax(y_train, axis=-1)\n",
    "    y_train_label = np.argmax(y_train_pu[pct_missing], axis=-1)\n",
    "    for y_pred, y_true, y_label in zip(y_train_pred, y_train_true, y_train_label):\n",
    "        lst.append(dict(y_pred=y_pred, y_true=y_true, y_label=y_label, split='train'))\n",
    "\n",
    "    y_test_pred = np.argmax(y_test_pred_enc, axis=-1)\n",
    "    y_test_true = np.argmax(y_test, axis=-1)\n",
    "    for y_pred, y_true in zip(y_test_pred, y_test_true):\n",
    "        lst.append(dict(y_pred=y_pred, y_true=y_true, y_label=y_true, split='test'))\n",
    "    \n",
    "    df = pd.DataFrame(lst)\n",
    "    df.to_csv('prediction_{}_pu_{}.csv'.format(title, pct_missing), index=False, mode='a')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
